import numpy as np
import pandas as pd
import statsmodels.api as sm

def noise(size=1000):
    """Draw i.i.d. standard-normal samples used as exogenous noise.

    Parameters
    ----------
    size : int, default 1000
        Number of samples. Callers in this module use the default, so every
        simulated series within one replication has the same length.

    Returns
    -------
    numpy.ndarray
        1-D array of ``size`` draws from N(0, 1) using the global NumPy RNG.
    """
    return np.random.normal(0, 1, size)

def fig1(n_sims=5000):
    """Simulate OLS estimates of Z's effect on Y when an omitted U confounds them.

    Each replication draws a true Z->Y effect uniformly from [0.05, 1) and a
    U->Z effect uniformly from [0, 0.1). It fixes X->Y = 1 and U->Y = -1,
    generates X, Y, U and Z, and regresses Y on X and Z with no intercept,
    leaving U out of the model.

    Prints the percentage of replications whose Z estimate has the opposite
    sign to the truth, and the percentage whose Z estimate is not significant
    at the 10% level.

    Parameters
    ----------
    n_sims : int, default 5000
        Number of simulated datasets.

    Returns
    -------
    pandas.DataFrame
        One row per replication with columns ``Truth_X``, ``Truth_Z``,
        ``X_param``, ``Z_param``, ``Significance`` (1 if p < 0.1 for Z),
        ``UZ``, ``UY`` and the relative biases
        ``Bias_Z``/``Bias_X`` = (estimate - truth) / truth.
    """
    res = []
    true_XY = 1
    true_UY = -1  # constant across replications
    for _ in range(n_sims):
        true_ZY = np.random.uniform(0.05, 1)
        true_UZ = np.random.uniform(0., .1)

        U = noise()
        Z = true_UZ*U + noise()
        X = Z + noise()
        Y = true_XY*X + true_ZY*Z + true_UY*U + noise()

        dt = pd.DataFrame({'X': X,
                           'Y': Y,
                           'U': U,
                           'Z': Z})
        # U is deliberately excluded: it is the omitted confounder.
        lm = sm.OLS(dt['Y'], dt[['X', 'Z']]).fit()
        significance = int(lm.pvalues.loc['Z'] < 0.1)
        res.append([true_XY, true_ZY, lm.params.loc['X'], lm.params.loc['Z'],
                    significance, true_UZ, true_UY])
    res = pd.DataFrame(res, columns=['Truth_X', 'Truth_Z', 'X_param', 'Z_param',
                                     'Significance', 'UZ', 'UY'])
    res['Bias_Z'] = (res['Z_param'] - res['Truth_Z'])/res['Truth_Z']
    res['Bias_X'] = (res['X_param'] - res['Truth_X'])/res['Truth_X']

    # === Fraction of replications where the sign was actually reversed.
    # A sign product of -1 means estimate and truth disagree. A zero estimate
    # (product 0) is not a reversal, hence "< 0" rather than "< 1".
    reversed_sign = np.sign(res['Z_param']) * np.sign(res['Truth_Z']) < 0
    print('Proportion of coefficients with opposite sign:',
          np.round(reversed_sign.mean()*100, 2), '%')
    print('Proportion of non-significant coefficients:',
          np.round((res['Significance'] == 0).mean()*100, 2), '%')
    return res


def fig1bottom(uy_values=None, n_reps=100):
    """Sweep the U->Y confounding strength and summarise the Z estimates.

    For each true U->Y coefficient, runs ``n_reps`` replications of the same
    design as ``fig1``. In each, Z->Y is drawn from U[0.05, 1), U->Z from
    U[0, 0.1), X->Y = 1, and Y is regressed on X and Z with no intercept.
    Records how often the Z estimate has the opposite sign to the truth and
    how often it is not significant at the 10% level.

    Parameters
    ----------
    uy_values : array-like, optional
        True U->Y coefficients to sweep. Defaults to
        ``np.linspace(-2.05, 0.05, 8)``.
    n_reps : int, default 100
        Replications per U->Y value.

    Returns
    -------
    pandas.DataFrame
        Columns ``UY``, ``Neg`` (% of replications with a sign-reversed Z
        estimate) and ``Zeros`` (% with p >= 0.1 for Z), both rounded to
        two decimals.
    """
    true_XY = 1
    if uy_values is None:
        uy_values = np.linspace(-2.05, 0.05, 8)
    r = []
    for true_UY in uy_values:
        n_reversed = 0
        n_insignificant = 0
        for _ in range(n_reps):
            true_ZY = np.random.uniform(0.05, 1)
            true_UZ = np.random.uniform(0., .1)
            U = noise()
            Z = true_UZ*U + noise()
            X = Z + noise()
            Y = true_XY*X + true_ZY*Z + true_UY*U + noise()

            dt = pd.DataFrame({'X': X,
                               'Y': Y,
                               'U': U,
                               'Z': Z})
            # U is deliberately excluded: it is the omitted confounder.
            lm = sm.OLS(dt['Y'], dt[['X', 'Z']]).fit()
            # Opposite sign only when the sign product is -1 (a zero estimate
            # is not a reversal).
            if np.sign(lm.params.loc['Z']) * np.sign(true_ZY) < 0:
                n_reversed += 1
            # Written as "not (p < 0.1)" so a NaN p-value counts as
            # non-significant.
            if not (lm.pvalues.loc['Z'] < 0.1):
                n_insignificant += 1
        r.append([true_UY,
                  np.round(n_reversed/n_reps*100, 2),
                  np.round(n_insignificant/n_reps*100, 2)])
    r = pd.DataFrame(r, columns=['UY', 'Neg', 'Zeros'])
    return r

####################################################################################
####################################################################################
####################################################################################
def fig2_top(l_coeff):
    """Simulate the effect of also conditioning on C when estimating X -> Y.

    Data-generating process per replication (every e is fresh N(0, 1) noise):

        Z = U + e,   X = Z + I + e,   C = I + L + e,
        Y = X + Z + U + l_coeff*L + e

    The true X->Y coefficient is 1. Three OLS models without an intercept
    are fit: Y ~ X + Z, Y ~ X + Z + C, and Y ~ X + Z + C + L.

    Parameters
    ----------
    l_coeff : float
        Coefficient of L in the equation for Y.

    Returns
    -------
    tuple of (pandas.DataFrame, float)
        A frame with one row per replication (5000), with columns:
        ``Truth_X``, ``X_param`` (Y~X+Z), ``X_biased`` (Y~X+Z+C),
        ``X_non_biased`` (Y~X+Z+C+L), ``Significance`` (1 if X has p < 0.1
        in Y~X+Z+C), ``R2``/``R2C`` (rsquared of the first two models), and
        the relative biases ``Bias_X``, ``Bias_XB`` and ``Bias_XL``.
        The second element is ``l_coeff``, returned unchanged.
    """
    rows = []
    for _ in range(5000):
        # Exogenous inputs, drawn in this order: U, I, L.
        U, I, L = noise(), noise(), noise()
        Z = U + noise()
        X = Z + I + noise()
        C = I + L + noise()
        Y = X + Z + U + l_coeff*L + noise()

        frame = pd.DataFrame({'X': X, 'Y': Y, 'U': U, 'Z': Z,
                              'L': L, 'C': C, 'I': I})
        target = frame['Y']
        fit_xz = sm.OLS(target, frame[['X', 'Z']]).fit()
        fit_xzc = sm.OLS(target, frame[['X', 'Z', 'C']]).fit()
        fit_xzcl = sm.OLS(target, frame[['X', 'Z', 'C', 'L']]).fit()

        is_sig = 1 if fit_xzc.pvalues.loc["X"] < 0.1 else 0
        rows.append([
            1,
            fit_xz.params.loc['X'],
            fit_xzc.params.loc['X'],
            fit_xzcl.params.loc['X'],
            is_sig,
            fit_xz.rsquared,
            fit_xzc.rsquared,
        ])

    res = pd.DataFrame(rows, columns=['Truth_X', 'X_param', 'X_biased', 'X_non_biased',
                                      'Significance', 'R2', 'R2C'])
    # Relative bias (estimate - truth) / truth for each model's X coefficient.
    for bias_col, est_col in (('Bias_X', 'X_param'),
                              ('Bias_XB', 'X_biased'),
                              ('Bias_XL', 'X_non_biased')):
        res[bias_col] = (res[est_col] - res['Truth_X'])/res['Truth_X']

    return res, l_coeff

def fig2_bottom():
    r"""Sweep the L->Y coefficient (eta) and summarise the bias from conditioning on C.

    Calls ``fig2_top`` for 10 values of ``l_coeff`` evenly spaced over
    [0.5, 5].

    Returns
    -------
    dx : list of list
        One entry per eta: ``[eta, mean(Bias_XB), std(Bias_XB),
        fraction of replications with X_biased < 0]``.
    z : pandas.DataFrame
        All ``Bias_XB`` values stacked. An extra column ``$\eta$`` records
        the eta that produced each row. The original per-run indices are
        kept, so index values repeat.
    """
    dx = []
    frames = []
    for lc in np.linspace(0.5, 5, 10):
        res_, _ = fig2_top(lc)
        dx.append([lc, res_.Bias_XB.mean(), res_.Bias_XB.std(),
                   (res_.X_biased < 0).mean()])
        # assign() returns a new frame. This avoids chained assignment on a
        # column slice, which triggers SettingWithCopyWarning.
        frames.append(res_[['Bias_XB']].assign(**{r'$\eta$': lc}))
    # Concatenate once, rather than growing a frame from an empty
    # DataFrame each iteration (quadratic, and deprecated in pandas >= 2.1).
    z = pd.concat(frames)
    return dx, z
####################################################################################
####################################################################################
####################################################################################

def fig3(n_sims=5000):
    """Compare partial and total OLS effects of S on L under random path strengths.

    Data-generating process per replication (coefficients redrawn each time,
    every e is fresh N(0, 1) noise):

        s_p, s_n ~ U[0, 1.25),  p_l, n_l ~ U[0, 0.5)
        P = s_p*S + e,   N = s_n*S + e,   L = p_l*P + n_l*N + e

    Regressing L on P and S estimates the partial effect of S, whose
    theoretical value is s_n*n_l. Regressing L on S alone estimates the total
    effect, s_n*n_l + s_p*p_l. No intercept is included in any model.

    Prints the fraction of replications in which the partial and total
    estimates have different signs, and the fraction in which the partial
    effect is not significant at the 10% level.

    Parameters
    ----------
    n_sims : int, default 5000
        Number of simulated datasets.

    Returns
    -------
    pandas.DataFrame
        Columns ``P_param`` and ``S_param`` (from L ~ P + S), ``S_total``
        (from L ~ S), ``S_null`` (S coefficient from L ~ P + S + N) and
        ``Significance`` (1 if S has p < 0.1 in L ~ P + S).
    """
    res = []
    for _ in range(n_sims):
        s_p = np.random.uniform(0.0, 1.25)
        s_n = np.random.uniform(0.0, 1.25)
        p_l = np.random.uniform(0.0, 0.5)
        n_l = np.random.uniform(0.0, 0.5)
        S = noise()
        P = s_p*S + noise()
        N = s_n*S + noise()
        L = p_l*P + n_l*N + noise()

        dt = pd.DataFrame({'S': S,
                           'P': P,
                           'N': N,
                           'L': L})
        lm = sm.OLS(dt['L'], dt[['P', 'S']]).fit()        # partial effect of S
        lmT = sm.OLS(dt['L'], dt[['S']]).fit()            # total effect of S
        lmN = sm.OLS(dt['L'], dt[['P', 'S', 'N']]).fit()  # both mediators held fixed

        significance = 1 if lm.pvalues['S'] < 0.1 else 0
        res.append([lm.params['P'], lm.params['S'], lmT.params['S'],
                    lmN.params['S'], significance])
    res = pd.DataFrame(res, columns=['P_param', 'S_param', 'S_total', 'S_null', 'Significance'])

    # The total effect comes from the L ~ S fit, already stored in S_total.
    # There is no need to refit that model.
    opposite = np.sign(res['S_param']) != np.sign(res['S_total'])
    print('Fraction of time when the partial and total effect have opposite sign:', opposite.mean())
    print('Fraction of time in which the partial effect was not significant:', 1- np.mean(res['Significance']))

    return res